{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.hparams import load_hparams_json\n",
    "from utils.util import intersperse\n",
    "import json\n",
    "import os\n",
    "import random, re\n",
    "from models.synthesizer.models.vits import Vits\n",
    "import torch\n",
    "import numpy as np\n",
    "import IPython.display as ipd\n",
    "from models.synthesizer.utils.symbols import symbols\n",
    "from models.synthesizer.utils.text import text_to_sequence\n",
    "from pypinyin import lazy_pinyin, Style\n",
    "\n",
    "# Directory holding the VITS config.json and generator checkpoint.\n",
    "CKPT_DIR = \"data/ckpt/synthesizer/vits5\"\n",
    "\n",
    "hps = load_hparams_json(f\"{CKPT_DIR}/config.json\")\n",
    "print(hps.train)\n",
    "model = Vits(\n",
    "    len(symbols),\n",
    "    hps[\"data\"][\"filter_length\"] // 2 + 1,\n",
    "    hps[\"train\"][\"segment_size\"] // hps[\"data\"][\"hop_length\"],\n",
    "    n_speakers=hps[\"data\"][\"n_speakers\"],\n",
    "    **hps[\"model\"])\n",
    "_ = model.eval()\n",
    "device = torch.device(\"cpu\")\n",
    "checkpoint = torch.load(f\"{CKPT_DIR}/G_56000.pth\", map_location=device)\n",
    "# The weights may live under either key depending on how the checkpoint was saved.\n",
    "state = checkpoint[\"model_state\"] if \"model_state\" in checkpoint else checkpoint[\"model\"]\n",
    "# NOTE(review): strict=False silently ignores missing/unexpected keys; inspect the\n",
    "# returned key lists if synthesis sounds wrong.\n",
    "model.load_state_dict(state, strict=False)\n",
    "\n",
    "# Root directory of the pre-extracted emotion embeddings (emo-*.npy).\n",
    "random_emotion_root = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
    "\n",
    "\n",
    "def _load_emotion_embedding(emotion):\n",
    "    \"\"\"Return the emotion-embedding tensor selected by `emotion` (see `tts`).\"\"\"\n",
    "    if emotion.endswith(\"wav\"):\n",
    "        # Imported lazily: only needed when extracting from a raw recording.\n",
    "        from models.synthesizer.preprocess_audio import extract_emo\n",
    "        import librosa\n",
    "        # librosa >= 0.10 only accepts `sr` as a keyword argument.\n",
    "        wav, sr = librosa.load(emotion, sr=16000)\n",
    "        return torch.FloatTensor(extract_emo(np.expand_dims(wav, 0), sr, embeddings=True))\n",
    "    if emotion == \"random_sample\":\n",
    "        # Only .npy files are embeddings; ignore anything else in the directory.\n",
    "        candidates = [f for f in os.listdir(random_emotion_root) if f.endswith(\".npy\")]\n",
    "        rand_emo = random.choice(candidates)\n",
    "        print(rand_emo)\n",
    "        return torch.FloatTensor(np.load(os.path.join(random_emotion_root, rand_emo))).unsqueeze(0)\n",
    "    if emotion.endswith(\"npy\"):\n",
    "        print(emotion)\n",
    "        return torch.FloatTensor(np.load(os.path.join(random_emotion_root, emotion))).unsqueeze(0)\n",
    "    raise ValueError(\n",
    "        f\"emotion参数不正确: {emotion!r} (expected a path ending in 'wav', \"\n",
    "        f\"a file name ending in 'npy' under random_emotion_root, or 'random_sample')\")\n",
    "\n",
    "\n",
    "def tts(txt, emotion, sid=0):\n",
    "    \"\"\"Synthesize `txt` with the loaded VITS model and play it inline.\n",
    "\n",
    "    Args:\n",
    "        txt: Chinese text; converted to TONE3 pinyin before tokenization.\n",
    "        emotion: a path ending in \"wav\" (embedding extracted from that audio),\n",
    "            a file name ending in \"npy\" under `random_emotion_root`, or\n",
    "            \"random_sample\" to pick a random .npy embedding from that directory.\n",
    "        sid: speaker id passed to the model.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `emotion` matches none of the accepted forms.\n",
    "    \"\"\"\n",
    "    txt = \" \".join(lazy_pinyin(txt, style=Style.TONE3, neutral_tone_with_five=False))\n",
    "    text_norm = text_to_sequence(txt, hps[\"data\"][\"text_cleaners\"])\n",
    "    # Blank interspersing is disabled here.\n",
    "    # NOTE(review): re-enable if the checkpoint was trained with add_blank=True:\n",
    "    # if hps[\"data\"][\"add_blank\"]:\n",
    "    #     text_norm = intersperse(text_norm, 0)\n",
    "    stn_tst = torch.LongTensor(text_norm)\n",
    "\n",
    "    with torch.no_grad():  # inference mode\n",
    "        x_tst = stn_tst.unsqueeze(0)\n",
    "        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n",
    "        sid = torch.LongTensor([sid])\n",
    "        emo = _load_emotion_embedding(emotion)\n",
    "        audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667,\n",
    "                            noise_scale_w=0.8, length_scale=1, emo=emo)[0][0, 0].data.float().numpy()\n",
    "    ipd.display(ipd.Audio(audio, rate=hps[\"data\"][\"sampling_rate\"], normalize=False))\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "推理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
    "# Emotion embeddings (file names under random_emotion_root), synthesized in order.\n",
    "# Alternatives noted during experiments:\n",
    "#   fast:  emo-T0055G2323S0179.wav_00.npy\n",
    "#   angry: T0055G1371S0363.wav\n",
    "emotion_refs = [\n",
    "    \"emo-T0055G4906S0052.wav_00.npy\",         # neutral\n",
    "    \"emo-15_4581_20170825202626.wav_00.npy\",  # sad\n",
    "    \"emo-T0055G2412S0498.wav_00.npy\",         # happy\n",
    "    \"emo-T0055G1344S0160.wav_00.npy\",         # angry\n",
    "    \"emo-T0055G2294S0476.wav_00.npy\",         # tired\n",
    "    \"emo-T0055G1671S0170.wav_00.npy\",         # anxious\n",
    "]\n",
    "for emo_file in emotion_refs:\n",
    "    tts(txt, emotion=emo_file, sid=100)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
    "# Six syntheses of the same sentence, each with a randomly drawn emotion embedding.\n",
    "for trial in range(6):\n",
    "    tts(txt, emotion='random_sample', sid=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
    "# Reference recordings, one per emotion label, named <label>.wav.\n",
    "# NOTE(review): machine-specific absolute path; adjust for your environment.\n",
    "ref_audio_dir = \"C:\\\\Users\\\\babys\\\\Music\"\n",
    "emotion_labels = [\"平淡\", \"激动\", \"疲惫\", \"兴奋\", \"沮丧\", \"开心\"]\n",
    "for label in emotion_labels:\n",
    "    print(label)\n",
    "    ref_wav = f\"{ref_audio_dir}\\\\{label}.wav\"\n",
    "    tts(txt, emotion=ref_wav, sid=100)\n",
    "# Single-recording variant:\n",
    "# tts(txt, emotion='D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav', sid=5)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预处理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.synthesizer.preprocess import preprocess_dataset\n",
    "from pathlib import Path\n",
    "from utils.hparams import HParams\n",
    "\n",
    "datasets_root = Path(\"../audiodata/\")\n",
    "\n",
    "# Audio / mel settings for preprocessing (keyword order kept as originally written).\n",
    "audio_settings = dict(\n",
    "    n_fft=1024,                            # filter_length\n",
    "    num_mels=80,\n",
    "    hop_size=256,                          # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)\n",
    "    win_size=1024,                         # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)\n",
    "    fmin=55,\n",
    "    min_level_db=-100,\n",
    "    ref_level_db=20,\n",
    "    max_abs_value=4.,                      # Gradient explodes if too big, premature convergence if too small.\n",
    "    sample_rate=16000,\n",
    "    rescale=True,\n",
    "    max_mel_frames=900,\n",
    "    rescaling_max=0.9,\n",
    "    preemphasis=0.97,                      # Filter coefficient to use if preemphasize is True\n",
    "    preemphasize=True,\n",
    "    # Mel visualization and Griffin-Lim\n",
    "    signal_normalization=True,\n",
    "    utterance_min_duration=1.6,            # Duration in seconds below which utterances are discarded\n",
    "    # Audio processing options\n",
    "    fmax=7600,                             # Should not exceed (sample_rate // 2)\n",
    "    allow_clipping_in_normalization=True,  # Used when signal_normalization = True\n",
    "    clip_mels_length=True,                 # If true, discards samples exceeding max_mel_frames\n",
    "    use_lws=False,                         # \"Fast spectrogram phase recovery using local weighted sums\"\n",
    "    symmetric_mels=True,                   # Mel range is [-max_abs_value, max_abs_value] if True, [0, max_abs_value] if False\n",
    "    trim_silence=False,                    # Use with sample_rate of 16000 for best results\n",
    ")\n",
    "hparams = HParams(**audio_settings)\n",
    "\n",
    "preprocess_dataset(\n",
    "    datasets_root=datasets_root,\n",
    "    out_dir=datasets_root.joinpath(\"SV2TTS\", \"synthesizer\"),\n",
    "    n_processes=8,\n",
    "    skip_existing=True,\n",
    "    hparams=hparams,\n",
    "    no_alignments=False,\n",
    "    dataset=\"aidatatang_200zh\",\n",
    "    emotion_extract=True,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.synthesizer.train_vits import run\n",
    "from pathlib import Path\n",
    "from utils.hparams import HParams\n",
    "import torch, os\n",
    "import torch.multiprocessing as mp\n",
    "\n",
    "datasets_root = Path(\"../audiodata/SV2TTS/synthesizer\")\n",
    "hparams = HParams(\n",
    "  model_dir = \"data/ckpt/synthesizer/vits\",\n",
    ")\n",
    "hparams.loadJson(Path(hparams.model_dir).joinpath(\"config.json\"))\n",
    "hparams.data[\"training_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
    "# NOTE(review): validation also reads train.txt; consider eval.txt from the\n",
    "# 10% split cell further down once it has been generated.\n",
    "hparams.data[\"validation_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
    "hparams.data[\"datasets_root\"] = str(datasets_root)\n",
    "\n",
    "n_gpus = torch.cuda.device_count()\n",
    "# mp.spawn(nprocs=0) starts no workers and returns immediately, so fail loudly instead.\n",
    "if n_gpus == 0:\n",
    "    raise RuntimeError(\"No CUDA device found (torch.cuda.device_count() == 0); \"\n",
    "                       \"mp.spawn would start zero training processes.\")\n",
    "# Rendezvous address/port for the spawned processes\n",
    "# (presumably consumed by torch.distributed initialization inside `run`).\n",
    "os.environ['MASTER_ADDR'] = 'localhost'\n",
    "os.environ['MASTER_PORT'] = '8899'\n",
    "mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hparams))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "挑选只有对应emo文件的meta数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "root = Path('../audiodata/SV2TTS/synthesizer')\n",
    "# Keep only metadata rows whose utterance has an extracted emotion embedding.\n",
    "dict_info = []\n",
    "with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
    "    for raw in dict_meta:\n",
    "        # Iterated lines keep their trailing newline, so test the stripped text.\n",
    "        if not raw.strip():\n",
    "            continue\n",
    "        # First field is the audio file name (e.g. audio-XXX.wav_00.npy); the matching\n",
    "        # embedding is emo/emo-XXX.wav_00.npy.\n",
    "        v = raw.split(\"|\")[0].replace(\"audio\", \"emo\")\n",
    "        emo_fpath = root.joinpath(\"emo\").joinpath(v)\n",
    "        if emo_fpath.exists():\n",
    "            dict_info.append(raw)\n",
    "\n",
    "# `with` closes the output file even if a write fails.\n",
    "meta2 = root.joinpath(\"train2.txt\")\n",
    "with meta2.open(\"w\", encoding=\"utf-8\") as metadata_file:\n",
    "    metadata_file.writelines(dict_info)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从训练集中抽取10%作为测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "root = Path('../audiodata/SV2TTS/synthesizer')\n",
    "# Deterministic split: every 10th non-blank row goes to eval.txt, the rest to train1.txt.\n",
    "dict_info1 = []\n",
    "dict_info2 = []\n",
    "count = 1\n",
    "with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
    "    for raw in dict_meta:\n",
    "        # Iterated lines keep their trailing newline, so test the stripped text.\n",
    "        if not raw.strip():\n",
    "            continue\n",
    "        if count % 10 == 0:\n",
    "            dict_info2.append(raw)\n",
    "        else:\n",
    "            dict_info1.append(raw)\n",
    "        count += 1\n",
    "\n",
    "# `with` closes each output file even if a write fails.\n",
    "for split_path, split_rows in ((root.joinpath(\"train1.txt\"), dict_info1),\n",
    "                               (root.joinpath(\"eval.txt\"), dict_info2)):\n",
    "    with split_path.open(\"w\", encoding=\"utf-8\") as metadata_file:\n",
    "        metadata_file.writelines(split_rows)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation set: assign a numeric speaker id to every row of eval.txt (writes eval2.txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "root = Path('../audiodata/SV2TTS/synthesizer')\n",
    "\n",
    "\n",
    "def speaker_of(row):\n",
    "    \"\"\"Return the speaker id embedded in a metadata row.\n",
    "\n",
    "    Takes characters 6..9 of the text after the first \"-\", e.g.\n",
    "    \"audio-T0055G1858S0342.wav_00.npy|...\" -> \"1858\".\n",
    "    \"\"\"\n",
    "    return row.split(\"-\")[1][6:10]\n",
    "\n",
    "\n",
    "with open(root.joinpath(\"eval.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
    "    # Blank lines have no speaker field and would break the parse.\n",
    "    rows = [raw for raw in dict_meta if raw.strip()]\n",
    "\n",
    "# Dense ids \"0\"..\"N-1\" assigned in sorted speaker order.\n",
    "spks = sorted({speaker_of(row) for row in rows})\n",
    "spk_id = {sp: str(i) for i, sp in enumerate(spks)}\n",
    "print(len(spks))\n",
    "\n",
    "# Append the numeric speaker id as a new last field of every row.\n",
    "meta2 = root.joinpath(\"eval2.txt\")\n",
    "with meta2.open(\"w\", encoding=\"utf-8\") as metadata_file:\n",
    "    for row in rows:\n",
    "        metadata_file.write(row.strip() + \"|\" + spk_id[speaker_of(row)] + \"\\n\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Not Recommended]\n",
    "Try mapping the transcript to a detailed phone-level format:\n",
    "ni3 hao3 -> n i3 <pad> h ao3\n",
    "\n",
    "After a couple of tests, I think this method does not improve the quality of the result and may cause monotonic alignment to crash."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from pathlib import Path\n",
    "datasets_root = Path(\"../audiodata/SV2TTS/synthesizer/\")\n",
    "\n",
    "# pinyin syllable -> phone string, one tab-separated entry per line.\n",
    "dictionary_fp = Path(\"../audiodata/ProDiff/processed/mandarin_pinyin.dict\")\n",
    "dict_map = {}\n",
    "with open(dictionary_fp, encoding='utf-8') as dict_file:\n",
    "    for l in dict_file:\n",
    "        item = l.split(\"\\t\")\n",
    "        # Skip lines without a tab instead of failing with IndexError.\n",
    "        if len(item) < 2:\n",
    "            continue\n",
    "        dict_map[item[0]] = item[1].replace(\"\\n\", \"\")\n",
    "\n",
    "with datasets_root.joinpath('train2.txt').open(\"w+\", encoding='utf-8') as f:\n",
    "    with open(datasets_root.joinpath('train.txt'), encoding='utf-8') as src:\n",
    "        for l in src:\n",
    "            # Blank lines have no transcript field (items[5]) to rewrite.\n",
    "            if not l.strip():\n",
    "                continue\n",
    "            items = l.strip().replace(\"\\n\", \"\").replace(\"\\t\", \" \").split(\"|\")\n",
    "            # items[5] presumably holds the space-separated pinyin transcript; each known\n",
    "            # syllable is expanded to its phones and followed by a \" _ \" separator.\n",
    "            phs_str = \"\"\n",
    "            for word in items[5].split(\" \"):\n",
    "                phs_str += dict_map.get(word, word)\n",
    "                phs_str += \" _ \"\n",
    "            items[5] = phs_str\n",
    "            f.write(\"|\".join(items) + \"\\n\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预处理后的数据可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import librosa.display\n",
    "import librosa, torch\n",
    "import numpy as np\n",
    "from utils.audio_utils import spectrogram, mel_spectrogram, load_wav_to_torch, spec_to_mel\n",
    "\n",
    "# The preprocessing cell above uses sample_rate=16000; label the plot axes accordingly.\n",
    "SR = 16000\n",
    "\n",
    "# One preprocessed utterance (waveform saved as .npy by the preprocessing step).\n",
    "x = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\audio\\\\audio-T0055G1858S0342.wav_00.npy\")\n",
    "\n",
    "# Waveform: `waveplot` was replaced by `waveshow` (librosa 0.8.1) and removed in 0.10.\n",
    "plt.figure(figsize=(14, 5))\n",
    "if hasattr(librosa.display, \"waveshow\"):\n",
    "    librosa.display.waveshow(x, sr=SR)\n",
    "else:\n",
    "    librosa.display.waveplot(x, sr=SR)\n",
    "plt.title(\"Waveform\")\n",
    "\n",
    "# Linear-frequency STFT magnitude in dB.\n",
    "X = librosa.stft(x)\n",
    "Xdb = librosa.amplitude_to_db(abs(X))\n",
    "plt.figure(figsize=(14, 5))\n",
    "librosa.display.specshow(Xdb, sr=SR, x_axis='time', y_axis='hz')\n",
    "plt.title(\"STFT magnitude (dB)\")\n",
    "\n",
    "# Mel spectrogram via the project's own helpers.\n",
    "# NOTE(review): positional args presumably (n_fft=1024, hop=256, win=1024) and\n",
    "# (n_fft=1024, n_mels=80, sr=16000, fmin=0, fmax=None); confirm in utils.audio_utils.\n",
    "audio = torch.from_numpy(x.astype(np.float32))\n",
    "audio_norm = audio.unsqueeze(0)  # add a leading (presumably batch) dimension\n",
    "spec = spectrogram(audio_norm, 1024, 256, 1024, center=False)\n",
    "spec = torch.squeeze(spec, 0)\n",
    "mel = spec_to_mel(spec, 1024, 80, 16000, 0, None)\n",
    "\n",
    "fig = plt.figure(figsize=(10, 8))\n",
    "ax2 = fig.add_subplot(211)\n",
    "im = ax2.imshow(mel, interpolation=\"none\", origin=\"lower\", aspect=\"auto\")\n",
    "ax2.set(title=\"Mel spectrogram\", xlabel=\"frame\", ylabel=\"mel bin\");"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "情感聚类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import librosa\n",
    "import IPython.display as ipd\n",
    "from random import sample\n",
    "\n",
    "# Upper bound on the number of embeddings loaded for clustering.\n",
    "MAX_EMBEDDINGS = 10000\n",
    "embs = []\n",
    "wavnames = []\n",
    "emo_root_path = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
    "wav_root_path = \"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\\"\n",
    "# Only aidatatang embeddings (emo-T*.npy) can be mapped back to a source wav below.\n",
    "candidates = [f for f in os.listdir(emo_root_path) if f.endswith(\".npy\") and f.startswith(\"emo-T\")]\n",
    "# random.sample raises ValueError if asked for more items than exist, so cap the count.\n",
    "for emo_fpath in sample(candidates, min(MAX_EMBEDDINGS, len(candidates))):\n",
    "    embs.append(np.expand_dims(np.load(emo_root_path + emo_fpath), axis=0))\n",
    "    # emo-T0055G1858S0342.wav_00.npy -> <wav_root>G1858\\T0055G1858S0342.wav\n",
    "    wav_fpath = wav_root_path + emo_fpath[9:14] + \"\\\\\" + emo_fpath.split(\"_00\")[0][4:]\n",
    "    wavnames.append(wav_fpath)\n",
    "print(len(embs))\n",
    "\n",
    "if not embs:\n",
    "    raise ValueError(f\"No emo-T*.npy embeddings found in {emo_root_path}\")\n",
    "x = np.concatenate(embs, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "import librosa\n",
    "import IPython.display as ipd\n",
    "from sklearn.cluster import AgglomerativeClustering, Birch, DBSCAN, KMeans, SpectralClustering\n",
    "\n",
    "# Number of clusters to form.\n",
    "n_clusters = 20\n",
    "# Other algorithms worth trying:\n",
    "# cluster_model = KMeans(n_clusters=n_clusters, random_state=10)\n",
    "# cluster_model = DBSCAN(eps=0.002, min_samples=2)\n",
    "# cluster_model = Birch(n_clusters=n_clusters, threshold=0.2)\n",
    "# cluster_model = SpectralClustering(n_clusters=n_clusters)\n",
    "# Named `cluster_model` so it does not overwrite the VITS `model` that tts() uses.\n",
    "cluster_model = AgglomerativeClustering(n_clusters=n_clusters)\n",
    "\n",
    "y_predict = cluster_model.fit_predict(x)\n",
    "\n",
    "\n",
    "def disp(wavname):\n",
    "    \"\"\"Load `wavname` resampled to 16 kHz and render an inline audio player.\"\"\"\n",
    "    # librosa >= 0.10 only accepts `sr` as a keyword argument.\n",
    "    wav, sr = librosa.load(wavname, sr=16000)\n",
    "    ipd.display(ipd.Audio(wav, rate=sr))\n",
    "\n",
    "\n",
    "# Group wav paths by predicted label (a dict also copes with DBSCAN's -1 noise label).\n",
    "classes = defaultdict(list)\n",
    "for label, wavname in zip(y_predict, wavnames):\n",
    "    classes[label].append(wavname)\n",
    "\n",
    "for label in sorted(classes):\n",
    "    print(\"类别:\", label, \"本类中样本数量:\", len(classes[label]))\n",
    "    # Preview up to 2 distinct clips per cluster.\n",
    "    for wavname in random.sample(classes[label], min(2, len(classes[label]))):\n",
    "        print(wavname)\n",
    "        disp(wavname)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "788ab866da3baa6c99886d56abb59fe71b6a552bf52c65473ecf96c784704db8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
